feat: partial weight sync (delta + selective)#1806
Open
nanjiangwill wants to merge 8 commits into
Conversation
Add a non-colocated weight-sync mode that broadcasts (current - snapshot)
sparse-encoded over NCCL and applies it additively on SGLang, instead of
broadcasting every parameter on every step. At GLM-4.7-355B scale the wire
shrinks ~30x (typical 2-3% density) and the sync phase becomes dominated by
gather+convert rather than broadcast.
Slime side:
- --update-weight-mode {full,delta}: pick the sync strategy.
- --delta-compression {sparse_indices,sparse_bitmask,dense}: wire encoding.
- --delta-dtype {fp16,bf16,fp32}: subtraction + apply happen at this dtype.
- --delta-full-interval N: periodic full sync (first sync is always full).
- --delta-artifact-dir PATH: optional async per-chunk artifact writer.
- Rejects --update-weight-mode delta with --colocate (CUDA IPC has no wire).
- UpdateWeightFromDistributedDelta extends UpdateWeightFromDistributed via
a single _on_chunk hook on the base class (template method); the base
class is otherwise unchanged in behaviour.
SGLang side (slime patch):
- WeightDeltaSpec / WeightDeltaParam / WeightDeltaEncoding wire protocol
in io_struct.
- Receiver decodes lazily (per-param), then applies via load_weights with
Tensor.copy_/fill_ rewired to add_ only inside param storage ranges.
- In-place add_ between bf16 param and fp32 delta auto-promotes for math
and casts back on store, so deltas keep fp32 precision without an extra
cast or scratch allocation.
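The "only inside param storage ranges" scoping above can be pictured as a sorted-interval lookup. A minimal stdlib sketch, with a hypothetical helper name and range representation (the actual patch works on captured `data_ptr` ranges):

```python
import bisect

def make_param_storage_predicate(ranges):
    # ranges: non-overlapping (start, end) address intervals covering
    # model param storage. Sort once; each lookup is O(log n) via bisect.
    ranges = sorted(ranges)
    starts = [start for start, _ in ranges]

    def inside(ptr):
        i = bisect.bisect_right(starts, ptr) - 1
        return i >= 0 and ptr < ranges[i][1]

    return inside

pred = make_param_storage_predicate([(100, 200), (400, 480)])
assert pred(150) and pred(400)          # inside an interval
assert not pred(250) and not pred(480)  # between intervals / past the end
```

Writes whose destination passes the predicate get the rewired behaviour (`add_` instead of `copy_`/`fill_`); everything else keeps stock semantics.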
Adds:
- examples/delta_compression/ with a non-colocated GLM-4.7-355B launcher.
- docs/en/advanced/delta-compression.md + zh translation.
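The delta + sparse-indices round trip described above can be sketched in a few lines. numpy stands in for the torch/NCCL path, and the function names are illustrative, not the PR's:

```python
import numpy as np

def encode_sparse_indices(current, snapshot, delta_dtype=np.float32):
    # Wire = (current - snapshot) at the delta dtype, kept only where it moved.
    delta = current.astype(delta_dtype) - snapshot.astype(delta_dtype)
    idx = np.flatnonzero(delta != 0).astype(np.int32)
    return idx, delta.ravel()[idx]

def apply_delta(param, idx, values):
    # Receiver side: additive apply at the changed positions only.
    flat = param.ravel()  # a view, so the add is in place
    flat[idx] += values.astype(param.dtype)
    return param

snapshot = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
current = snapshot.copy()
current[2] = 3.5  # one changed position out of four: 25% density
idx, vals = encode_sparse_indices(current, snapshot)
rebuilt = apply_delta(snapshot.copy(), idx, vals)
assert np.array_equal(rebuilt, current)
```

At 2-3% density the (index, value) pairs are a small fraction of a dense broadcast, which is where the ~30x wire reduction comes from.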
- examples/README.md: drop "at 355B scale" from the general directory listing.
- examples/delta_compression/README.md: replace "(typical 2-3% density at
355B)" with a model-agnostic note about RL fine-tuning density.
- docs/{en,zh}/advanced/delta-compression.md: same de-anchoring in the
overview paragraph; specific numbers stay in result tables only.
- examples/delta_compression/run-glm4.7-355B-A32B-delta.sh: drop unused
SCRIPT_DIR and the delta-flavoured comment on --update-weight-buffer-size
(the flag is general).
…plain lossless

Replace every `--delta-full-interval 10000` occurrence with 30 (the actual argparse default) and add an inline note at each site: setting the flag to a very large integer (e.g. 10000) effectively disables periodic full syncs, which is fine because with `--delta-dtype fp32` the apply is lossless — every bf16 value is exactly representable in fp32, the subtraction captures the exact difference between two stored bf16 values, and the receiver's in-place bf16 += fp32 add reproduces the trainer's bf16 state bit-for-bit on rounding back, so no error accumulates across deltas.

Updates: example script, example README, en + zh docs.
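The losslessness argument round-trips in any (narrow, wide) float pair where the wide format contains the narrow one exactly. A stdlib sketch using float32/float64 as a stand-in for bf16/fp32 (numerically analogous, not the actual torch path):

```python
import struct

def to_f32(x):
    # Round a Python float (f64) to its nearest float32, returned as f64.
    return struct.unpack("f", struct.pack("f", x))[0]

# Two stored low-precision states, one small update apart.
snapshot = [to_f32(v) for v in (0.1, -1.7, 3.14159, 2.5e-3)]
current = [to_f32(s * 1.001) for s in snapshot]

# Sender: every f32 value is exact in f64, and for nearby operands the
# f64 subtraction is itself exact (Sterbenz), so delta is the true difference.
delta = [c - s for c, s in zip(current, snapshot)]

# Receiver: promote, add in the wide format, round back to storage precision.
rebuilt = [to_f32(s + d) for s, d in zip(snapshot, delta)]
assert rebuilt == current  # bit-for-bit round trip
```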
…artial*
Add a second partial-update mode 'selective' alongside the existing 'delta'.
Both share the snapshot, sparse encoding, periodic-base-sync, and bucketed
broadcast machinery; they differ in what's on the wire and how the receiver
applies it:
delta — wire = (current − snapshot) at delta_dtype;
receiver: param += delta (in-place add, auto-promotes for fp32
math, casts back to param dtype on store).
selective — wire = new param values at changed positions, with NaN as the
"unchanged" sentinel in the dense decoded tensor;
receiver: param[~isnan(src)] = src[~isnan(src)] (selective
overwrite, leaves NaN positions untouched).
Selective is lossless by construction (no arithmetic), the wire values portion
is ~½ the size of fp32 delta, and the per-element apply is a direct masked
copy. Selective requires float param dtype on the wire (validated at decode);
slime always sends HF-format floats so this holds in practice.
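The receiver-side selective apply is just a masked overwrite. A numpy sketch of `param[~isnan(src)] = src[~isnan(src)]` (the function wrapper is illustrative):

```python
import numpy as np

def apply_selective(param, src):
    # src is the dense decoded tensor: new values at changed positions,
    # NaN as the "unchanged" sentinel everywhere else. NaN is safe as a
    # sentinel precisely because the wire dtype must be float (validated
    # at decode) and NaN never collides with a real parameter value.
    mask = ~np.isnan(src)
    param[mask] = src[mask].astype(param.dtype)
    return param

param = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
src = np.array([np.nan, 2.5, np.nan, 4.5], dtype=np.float32)
apply_selective(param, src)
assert param.tolist() == [1.0, 2.5, 3.0, 4.5]
```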
CLI rename: every flag now lives under --update-weight-*. Scope is encoded in
the name itself:
--update-weight-mode {full, delta, selective}
--update-weight-delta-dtype # delta-only
--update-weight-partial-encoding # delta + selective
--update-weight-base-sync-interval # delta + selective
--update-weight-partial-artifact-dir # delta + selective
--sglang-update-weight-partial-chunk-bytes # delta + selective (receiver)
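The sparse_indices vs sparse_bitmask choice under --update-weight-partial-encoding has a simple cost crossover, assuming int32 indices and a 1-bit-per-element mask (widths assumed here, not stated in the PR body):

```python
def bookkeeping_bytes(numel, density):
    # sparse_indices: one int32 index per changed element.
    indices = 4 * int(numel * density)
    # sparse_bitmask: one bit per element, independent of density.
    bitmask = numel // 8
    return indices, bitmask

# The value payloads are identical in both encodings, so the crossover is
# where 4 * numel * d == numel / 8, i.e. d == 1/32 == 3.125%.
numel = 1_000_000
assert bookkeeping_bytes(numel, 0.02)[0] < bookkeeping_bytes(numel, 0.02)[1]
assert bookkeeping_bytes(numel, 0.05)[0] > bookkeeping_bytes(numel, 0.05)[1]
```

Below ~3.1% density the indices win; above it, the bitmask does.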
Internal naming follows: WeightDelta* → PartialWeight* on the wire,
UpdateWeightFromDistributedDelta → UpdateWeightFromDistributedPartial,
DeltaSync → PartialSync, _decode_sparse_delta → _decode_sparse_partial,
_additive_load_context stays + new _selective_load_context (they share a
_param_storage_predicate helper + _patched_in_place_writes scaffolding).
File: update_weight_from_distributed_delta.py → ..._partial.py.
The slime orchestrator is one class (UpdateWeightFromDistributedPartial) with
mode branching in _enqueue_partial_chunk; _send_partial_weights body is
shared. The SGLang receiver dispatches on load_format ("delta" or "selective")
to one of two thin entry points that share _update_partial_weights_from_distributed.
Wire schema: single PartialWeightSpec class shared by both modes (structurally
identical wire format); request field is .partial (was .delta).
Docs (en + zh) and the example script use the new flag names and document
selective mode without perf numbers (pending experiment).
The feature now covers two peer modes (delta + selective), so 'delta
compression' is misleading at the umbrella level. Rename to 'partial weight
sync':
docs/{en,zh}/advanced/delta-compression.md → partial-weight-sync.md
examples/delta_compression/ → examples/partial_weight_sync/
examples/.../run-glm4.7-355B-A32B-delta.sh → ..._partial.sh
Internal references and index.rst entries updated. The example script and
README now show two explicit PARTIAL_ARGS blocks (delta active by default,
selective commented out) so users can flip modes by swapping which block is
uncommented.
Adds inspiration / prior-art references (Cursor Composer 2 + Fireworks AI
for delta; arXiv:2509.19128 for selective) and a placeholder for selective
W&B traces. Also documents that this feature only changes what bytes ship
on the wire, so any future communication-layer optimization in slime stacks
additively on top.
…sh sglang.patch

- `--update-weight-base-sync-interval` default: 30 → 9999. Both partial modes are lossless under their defaults (delta with fp32 math, selective by construction), so periodic base syncs aren't needed for correctness. Help text and docs/examples updated to explain the override (e.g. 30 to verify against periodic full broadcasts).
- Refresh docker/patch/latest/sglang.patch with the renamed CLI / wire types.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Trainer-side observability for partial weight sync.

New wandb metrics on the rollout/step axis, emitted by UpdateWeightFromDistributedPartial:

- perf/update_weights_density: fraction of positions that moved
- perf/update_weights_is_base_sync: 0/1 flag (lets you disambiguate spikes)
- perf/update_weights_wire_bytes: bytes actually shipped per sync

Plumbing: weight updaters gain an update_weight_metrics dict + pop_metrics(). The actor drains it via log_perf_data(..., extra_metrics=...), mirroring slime's existing rollout_extra_metrics pattern in _log_rollout_data.

Other changes:

- nnz field on PartialChunk so density accounting is symmetric with byte_size
- Reorder selective before delta in all prose/tables/help text/example script (example script now defaults to selective; delta block commented out)
- README placeholder for the density plot under the Selective mode results section, with a note explaining step 0 is omitted (always base-sync = 1.0)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
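The drain pattern behind pop_metrics() can be sketched as follows. Only the metric names and the dict + pop_metrics() shape come from the commit; the class body is a hypothetical sketch:

```python
class PartialWeightUpdaterSketch:
    def __init__(self):
        self.update_weight_metrics = {}

    def _record_sync(self, nnz, numel, wire_bytes, is_base_sync):
        # Filled during the sync; drained once per rollout step.
        self.update_weight_metrics = {
            "perf/update_weights_density": nnz / numel,
            "perf/update_weights_is_base_sync": int(is_base_sync),
            "perf/update_weights_wire_bytes": wire_bytes,
        }

    def pop_metrics(self):
        # Drain-and-clear, so a step with no sync logs nothing stale.
        metrics, self.update_weight_metrics = self.update_weight_metrics, {}
        return metrics

u = PartialWeightUpdaterSketch()
u._record_sync(nnz=2_500_000, numel=100_000_000,
               wire_bytes=15_000_000, is_base_sync=False)
m = u.pop_metrics()
assert m["perf/update_weights_density"] == 0.025
assert u.pop_metrics() == {}  # second drain is empty
```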
…unify modes

The compute step now returns (payload, mask) per param via a new PartialPayload dataclass; the encoder reads the mask directly instead of re-deriving it from a sentinel. NaN is no longer materialized on the sender in the common sparse paths — only inside _encode_dense for selective (debug-only encoding).

- New PartialPayload(name, payload, mask) carrying per-param compute output.
- compute_delta + compute_selective merged into one compute_payload(mode); the per-mode logic lives in a closure dispatched once before the shared loop.
- _encode_sparse reads (pp.payload, pp.mask) directly; no predicate threading.
- _make_indices_kv / _make_bitmask_kv flattened to plain _indices_kv / _bitmask_kv (the factories were vestigial once is_active capture was removed).
- Dense encoding extracted into _encode_dense; it lazily NaN-marks selective tensors for the receiver-side sentinel — the only sender path that still touches NaN.

Net: selective saves one full-size tensor allocation per param on the sender (no more torch.where(mask, tensor, NaN) materialization in the common case).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
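The compute-side shape this commit describes, in a minimal numpy sketch. The closure dispatch mirrors the commit's description; everything else, including the toy mask rule, is illustrative:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class PartialPayload:
    name: str
    payload: np.ndarray  # values destined for the wire (read at mask positions)
    mask: np.ndarray     # bool mask of changed positions

def compute_payload(mode, params, snapshots):
    # Per-mode logic picked once, before the shared per-param loop.
    if mode == "delta":
        fn = lambda cur, snap: cur - snap   # wire = difference
    elif mode == "selective":
        fn = lambda cur, snap: cur          # wire = new values
    else:
        raise ValueError(f"unknown mode: {mode}")
    out = []
    for name, cur in params.items():
        snap = snapshots[name]
        out.append(PartialPayload(name, fn(cur, snap), cur != snap))
    return out

params = {"w": np.array([1.0, 2.0, 3.0])}
snaps = {"w": np.array([1.0, 2.0, 4.0])}
(pp,) = compute_payload("delta", params, snaps)
assert pp.mask.tolist() == [False, False, True]
assert pp.payload.tolist() == [0.0, 0.0, -1.0]
```

The encoder then reads (pp.payload, pp.mask) directly, so no NaN sentinel has to be materialized on the sender in the sparse paths.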
Summary
Non-colocated weight sync that broadcasts only the changed-position payload over NCCL instead of full weights. Two modes (peer; pick at runtime):
- `selective` — broadcast new param values at changed positions only (NaN as the "unchanged" sentinel); receiver overwrites those positions, leaves others alone. Lossless by construction (no arithmetic); wire values portion ~½ the size of fp32 delta. Inspired by arXiv:2509.19128.
- `delta` — broadcast `(current − snapshot)` sparse-encoded; receiver applies additively. Lossless with `--update-weight-delta-dtype fp32` (default). Inspired by Cursor Composer 2 and Fireworks AI — Frontier RL Is Cheaper Than You Think.

Measured on GLM-4.7-355B-A32B non-colocated (8 actor + 8 rollout nodes, 64 rollout H100s) in delta mode: the wire shrinks ~30× (~170 GB → ~5.9 GB) and the broadcast stops dominating the sync phase. Per-sync density logged at 2–3%, which sits below the 1/32 ≈ 3.125% break-even, so `sparse_indices` is the right encoding for this workload.

Artifacts:

- `docs/en/advanced/partial-weight-sync.md` · `docs/zh/advanced/partial-weight-sync.md`
- `examples/partial_weight_sync/run-glm4.7-355B-A32B-partial.sh` (two pre-built `PARTIAL_ARGS` blocks: selective active by default, delta commented out)
- `docker/patch/latest/sglang.patch`; applied during the docker image build.

CLI surface
Trainer side (slime):
- `--update-weight-mode` full / selective / delta
- `--update-weight-delta-dtype` fp16 / bf16 / fp32
- `--update-weight-partial-encoding` sparse_indices / sparse_bitmask / dense
- `--update-weight-base-sync-interval`
- `--update-weight-partial-artifact-dir`

SGLang side (auto-mirrored via the `--sglang-` prefix):

- `--sglang-update-weight-partial-chunk-bytes` (bounds each `model.load_weights` call on apply)

`--update-weight-mode={selective,delta}` is rejected with `--colocate` — CUDA IPC has no wire to compress.

Code shape
Slime:
- `UpdateWeightFromDistributedPartial` extends `UpdateWeightFromDistributed` via a single `_on_chunk` hook on the base class (Template Method). Base behaviour is otherwise unchanged.
- One subclass for both modes; mode branching lives in `_enqueue_partial_chunk` (compute_selective vs compute_delta) and the `load_format` string passed to the receiver.

SGLang patch (in `docker/patch/latest/sglang.patch`):

- `PartialWeightSpec` + `PartialWeightEncoding` + `PartialWeightParam`. Receiver dispatches on `load_format` (`"selective"` vs `"delta"`).
- `_update_partial_weights_from_distributed(..., mode)` picks the fill_value (NaN vs 0) and apply context (selective vs additive) inline. Decode is lazy per-param; peak HBM is bounded by `encoded_buffers + chunk_byte_cap`.
- `_selective_load_context` (selective) and `_additive_load_context` (delta) share a `_param_storage_predicate` + `_patched_in_place_writes` helper. Both rewire `Tensor.copy_`/`fill_`, scoped to writes whose destination is inside model param storage (bisect on captured `data_ptr` ranges).
- Selective apply: `param[~isnan(src)] = src[~isnan(src)]` — overwrites changed positions, leaves NaN-marked positions untouched.
- Delta apply: `add_` auto-promotes bf16 += fp32 to fp32 math and casts back on store, so deltas keep fp32 precision without an explicit cast.
- `post_load_weights` (DeepSeek's `w_kc`/`w_vc` materialization) is wrapped to run with unmodified `copy_`/`fill_`, so derived tensors overwrite correctly under both contexts.

Results
[results table residue: both entries were labelled `sparse_indices`; table contents did not survive extraction]

Selective mode
Pending — experiment running. The example README has a placeholder section for selective traces and per-sync density/wall-clock numbers; will be filled in once the run completes.
Delta mode
W&B traces are in `examples/partial_weight_sync/README.md` (raw_reward / train-rollout logprob abs diff / update_weights_time).

Why this composes with future communication work
The feature only changes what bytes get shipped. The NCCL broadcast, Ray lock, bucket scheduling, and send/receive layers are untouched. Any future slime improvement to the weight-update communication path (better compute/broadcast overlap, pipeline-parallel sends, NIC-level tricks) stacks additively on top of the speedups here — both modes inherit it for free.
🤖 Generated with Claude Code